package com.lw.leetcode.tree.b;

import java.util.ArrayList;
import java.util.List;

/**
 * Created with IntelliJ IDEA.
 *
 * @author liw
 * @version 1.0
 * @date 2022/11/14 14:34
 */
public class MostProfitablePath {

    public static void main(String[] args) {
        MostProfitablePath test = new MostProfitablePath();

        // 6
        int[][] edges = {{0, 1}, {1, 2}, {1, 3}, {3, 4}};
        int bob = 3;
        int[] amount = {-2, 4, 2, -4, 6};

        int i = test.mostProfitablePath(edges, bob, amount);
        System.out.println(i);

    }

    /** Nodes of the current query, indexed by node id. */
    private Node[] arr;
    /** Undirected adjacency lists of the current query, indexed by node id. */
    private List<Integer>[] items;
    /** Best income found for the current query; reset at the start of every call. */
    private int max = Integer.MIN_VALUE;

    /**
     * Computes the best income Alice can collect on a root-to-leaf walk.
     *
     * <p>Alice walks from node 0 down to a leaf of the tree rooted at 0. Bob walks from
     * {@code bob} up to node 0. Both move one edge per step, starting at the same time. At
     * every node on her path, Alice gains:
     * <ul>
     *   <li>its full amount if she arrives before Bob, or Bob never passes that node;</li>
     *   <li>half of it if they arrive at the same step;</li>
     *   <li>nothing if Bob got there first.</li>
     * </ul>
     *
     * <p>All traversals are iterative, so very deep (path-shaped) trees do not overflow the
     * call stack.
     *
     * @param edges  undirected edges {@code {a, b}} forming a tree over nodes
     *               {@code 0..amount.length-1}
     * @param bob    Bob's starting node
     * @param amount per-node value (negative = cost, positive = reward); must be non-empty
     * @return the maximum total Alice can obtain over all leaves
     */
    public int mostProfitablePath(int[][] edges, int bob, int[] amount) {
        int length = amount.length;
        arr = new Node[length];
        items = newAdjacency(length);
        // The answer is accumulated in a field, so a reused instance must start fresh.
        max = Integer.MIN_VALUE;
        for (int i = 0; i < length; i++) {
            arr[i] = new Node(i, amount[i]);
        }
        for (int[] edge : edges) {
            items[edge[0]].add(edge[1]);
            items[edge[1]].add(edge[0]);
        }
        List<Node> order = buildTree(arr[0]);
        markBobTimes(arr[bob]);
        getMax(order);
        return max;
    }

    /** Creates {@code n} empty adjacency lists. */
    @SuppressWarnings("unchecked") // generic array creation; every slot is filled with ArrayList<Integer>
    private static List<Integer>[] newAdjacency(int n) {
        List<Integer>[] adjacency = new List[n];
        for (int i = 0; i < n; i++) {
            adjacency[i] = new ArrayList<>();
        }
        return adjacency;
    }

    /**
     * Roots the tree at {@code root} using breadth-first search.
     *
     * <p>For every node, this fills in {@code pre} (the parent) and {@code list} (the
     * children).
     *
     * @return all reachable nodes in BFS order, so every parent precedes its children
     */
    private List<Node> buildTree(Node root) {
        List<Node> order = new ArrayList<>(arr.length);
        order.add(root);
        // The order list doubles as the BFS queue: head walks it while children are appended.
        for (int head = 0; head < order.size(); head++) {
            Node node = order.get(head);
            for (int v : items[node.val]) {
                Node child = arr[v];
                if (child == node.pre) {
                    continue; // do not walk back up to the parent
                }
                child.pre = node;
                node.list.add(child);
                order.add(child);
            }
        }
        return order;
    }

    /**
     * Records, for each node on Bob's path to the root, the 1-based step at which Bob reaches
     * it.
     *
     * <p>Bob's start gets 1, its parent gets 2, and so on. Nodes Bob never visits keep 0.
     */
    private void markBobTimes(Node start) {
        int time = 1;
        for (Node node = start; node != null; node = node.pre) {
            node.s = time++;
        }
    }

    /**
     * Propagates Alice's running income top-down in BFS order.
     *
     * <p>Updates {@link #max} at every leaf.
     *
     * @param order nodes in an order where each parent precedes its children
     */
    private void getMax(List<Node> order) {
        int[] step = new int[arr.length];
        int[] income = new int[arr.length];
        for (Node node : order) {
            Node parent = node.pre;
            // Alice is at the root at step 1, one step deeper per level.
            int t = parent == null ? 1 : step[parent.val] + 1;
            step[node.val] = t;
            income[node.val] = (parent == null ? 0 : income[parent.val]) + aliceShare(node, t);
            if (node.list.isEmpty()) {
                max = Math.max(max, income[node.val]);
            }
        }
    }

    /** Returns what Alice gains at {@code node} when she arrives there at {@code step}. */
    private static int aliceShare(Node node, int step) {
        if (node.s == step) {
            // Simultaneous arrival: split the amount (arithmetic shift, i.e. floor of half).
            return node.amount >> 1;
        }
        if (node.s == 0 || node.s > step) {
            // Bob never comes here, or comes after Alice.
            return node.amount;
        }
        // Bob already passed and opened the gate.
        return 0;
    }

    /** A tree node together with per-query traversal state. */
    public static class Node {
        /** Node id. */
        private final int val;
        /** Parent in the tree rooted at node 0; {@code null} for the root. */
        private Node pre;
        /** 1-based step at which Bob reaches this node, or 0 if he never does. */
        private int s = 0;
        /** Children in the tree rooted at node 0. */
        private final List<Node> list = new ArrayList<>();
        /** Value at this node. */
        private final int amount;

        public Node(int val, int amount) {
            this.val = val;
            this.amount = amount;
        }
    }

}
